{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aP2Rnk_e2QwM"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "cellView": "form",
        "id": "p8XCrIru1-dZ"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f5uhKzJJ5xJA"
      },
      "source": [
        "# X-Ray Smoker Analysis\n",
        "\n",
        "## Overview\n",
        "\n",
        "This notebook uses MedGemma to analyze a chest X-ray image and estimate whether the patient is more likely a smoker or a non-smoker, based on visible lung patterns and radiological signs.\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/MedGemma/[MedGemma]Inference_using_HuggingFace.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iC0-3bMg2V98"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the MedGemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tR2F9oKa5vQu"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "Q-hF7kzr2QIf"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade --quiet accelerate bitsandbytes transformers"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "qe0vxem_2rjH"
      },
      "outputs": [],
      "source": [
        "import requests\n",
        "from PIL import Image\n",
        "from io import BytesIO\n",
        "import torch\n",
        "from transformers import pipeline, BitsAndBytesConfig"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wz5GtmFC6KsS"
      },
      "source": [
        "## Load the model\n",
        "\n",
        "Load the pre-trained `google/medgemma-4b-it` model with a 4-bit BitsAndBytes quantization configuration, which reduces the GPU memory needed to hold the weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "c1KKG8Gh3DjR"
      },
      "outputs": [],
      "source": [
        "model_id = \"google/medgemma-4b-it\"\n",
        "task = \"image-text-to-text\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "2gE3tnvW2uWB"
      },
      "outputs": [],
      "source": [
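        "# 4-bit NF4 quantization reduces the memory needed to hold the model weights.\n",
        "quantization_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_compute_dtype=\"float16\",  # dtype used for computation on the quantized weights\n",
        "    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        ")"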
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "jhs_vfWv3L5s"
      },
      "outputs": [],
      "source": [
        "model_kwargs = dict(\n",
        "    torch_dtype=torch.bfloat16,\n",
        "    device_map=\"cuda:0\",\n",
        "    quantization_config=quantization_config,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "ZxofPzVA3CIl"
      },
      "outputs": [],
      "source": [
        "pipe = pipeline(task, model=model_id, model_kwargs=model_kwargs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "zFolbnsV3QLw"
      },
      "outputs": [],
      "source": [
        "# Disable sampling so that decoding is deterministic and the output is reproducible.\n",
        "pipe.model.generation_config.do_sample = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kEcYUju96jzw"
      },
      "source": [
        "## Get the image\n",
        "\n",
        "Download the chest X-ray from a URL and load it into a PIL image through an in-memory `BytesIO` buffer. To use a local file instead, open it directly, for example `Image.open(\"/path/image.png\")`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "-tXRG5Pc49Ng"
      },
      "outputs": [],
      "source": [
        "image_url = \"https://www.clinicaladvisor.com/wp-content/uploads/sites/11/2018/12/lungs_0913newsline_451310.jpg\"\n",
        "\n",
        "headers = {\n",
        "    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',\n",
        "}\n",
        "response = requests.get(image_url, headers=headers)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "fEzjb4N2493G"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "200"
            ]
          },
          "execution_count": 23,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "response.status_code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 24,
      "metadata": {
        "id": "wVPcVPYC4eeg"
      },
      "outputs": [],
      "source": [
        "image = Image.open(BytesIO(response.content))\n",
        "\n",
        "if image.mode != 'RGB':\n",
        "    image = image.convert('RGB')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X5rJ-tQ06HLe"
      },
      "source": [
        "## Define the Prompt Template"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "oVQ5ACku4XMv"
      },
      "outputs": [],
      "source": [
        "query = \"\"\"Based on the attached chest X-ray,\n",
        "can you analyze and tell me whether the patient is more likely a smoker or a non-smoker?\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 25,
      "metadata": {
        "id": "xKMAwJmn3Sdw"
      },
      "outputs": [],
      "source": [
        "system = \"\"\"\n",
        "You are a medical imaging expert AI trained to analyze chest X-rays and identify potential indicators of smoking history.\n",
        "Given a chest X-ray image, you must determine whether the scan is more likely to belong to a smoker or non-smoker based on clinical radiological signs such as:\n",
        "\n",
        "- Lung hyperinflation\n",
        "- Bullae or blebs\n",
        "- Increased lung lucency\n",
        "- Irregular reticular or nodular patterns\n",
        "- Upper lobe cystic changes\n",
        "- Emphysematous changes or interstitial thickening\n",
        "\n",
        "Always include a step-by-step explanation based on observed patterns in the image, and mention the level of confidence (e.g., high, moderate, low).\n",
        "Be cautious not to make a definitive diagnosis but rather offer a likelihood-based reasoning.\n",
        "\"\"\"\n",
        "\n",
        "messages = [\n",
        "    {\n",
        "        \"role\": \"system\",\n",
        "        \"content\": [{\"type\": \"text\", \"text\": system}]\n",
        "    },\n",
        "    {\n",
        "        \"role\": \"user\",\n",
        "        \"content\": [\n",
        "            {\"type\": \"text\", \"text\": query},\n",
        "            {\"type\": \"image\", \"image\": image}\n",
        "        ]\n",
        "    }\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bBujJq2K6w12"
      },
      "source": [
        "## Execute the Prompt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bucjxtfT3Y2X"
      },
      "outputs": [],
      "source": [
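        "# Run inference on the chat messages (system prompt, user query, and image).\n",
        "output = pipe(text=messages, max_new_tokens=2048)\n",
        "# The pipeline returns the full conversation; the last message is the model's reply.\n",
        "response = output[0][\"generated_text\"][-1][\"content\"]"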
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 27,
      "metadata": {
        "id": "xVkm_Tqb5nys"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Okay, I will analyze the chest X-ray and provide a likelihood assessment based on the provided clinical radiological signs.\n",
            "\n",
            "**Analysis:**\n",
            "\n",
            "1.  **Lung Hyperinflation:** The image shows a relatively hyperinflated appearance of the lungs, which can be a sign of emphysema, a condition often associated with smoking.\n",
            "2.  **Increased Lung Lucency:** The lungs appear more translucent or \"darker\" than normal, indicating increased air in the lungs. This is a common finding in emphysema.\n",
            "3.  **Bullae or Blebs:** There are some rounded, air-filled structures visible in the lungs, which could be bullae or blebs. These are also associated with emphysema.\n",
            "4.  **Emphysematous Changes or Interstitial Thickening:** The overall appearance suggests some degree of emphysema, which can be associated with smoking.\n",
            "\n",
            "**Conclusion:**\n",
            "\n",
            "Based on the observed signs of lung hyperinflation, increased lung lucency, and possible bullae/blebs, the image is more likely to belong to a smoker.\n",
            "\n",
            "**Confidence Level:** Moderate. While the findings are suggestive of smoking-related lung disease, further evaluation and correlation with the patient's history are necessary for a definitive conclusion.\n",
            "\n"
          ]
        }
      ],
      "source": [
        "print(response)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[MedGemma]Inference_using_HuggingFace.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
